package websocket

import (
	"context"
	"errors"
	message2 "gitee.com/zackeus/go-boot/websocket/message"
	"gitee.com/zackeus/go-zero/core/logx"
	"gitee.com/zackeus/go-zero/core/threading"
	"gitee.com/zackeus/goutil/byteutil"
	"gitee.com/zackeus/goutil/strutil"
	"github.com/gorilla/websocket"
	"sync"
	"time"
)

// ClientOption customizes a Client while it is being initialized.
type ClientOption func(c *Client)

// OnPingCallback is the ping callback; data is the payload of the received ping frame.
type OnPingCallback func(client *Client, data string)

// OnPongCallback is the pong callback; data is the payload of the received pong frame.
type OnPongCallback func(client *Client, data string)

// OnMessageCallback is the callback for every message read from the connection.
type OnMessageCallback func(client *Client, data *message2.Message)

// OnConnectedCallback is the "connected" callback.
// isForce reports whether the connection was a forced login.
type OnConnectedCallback func(client *Client, isForce bool)

// OnDisConnectedCallback is the "disconnected" callback. code and reason describe
// the close; isForce presumably has the same "forced login" meaning as in
// OnConnectedCallback (TODO confirm against the engine).
type OnDisConnectedCallback func(client *Client, code message2.CloseCode, reason string, isForce bool)

// OnErrorCallback is the error callback.
type OnErrorCallback func(client *Client, err error)

// Client is a single websocket connection owned by an Engine.
type Client struct {
	// ctx is the client's context; cancel is called by Close.
	ctx    context.Context
	cancel context.CancelFunc

	// once guards init, closeOnce guards Close, and mutex serializes writes made by Send.
	once      sync.Once
	closeOnce sync.Once
	mutex     sync.Mutex

	// id is the logical identifier (e.g. a user's employee number); key identifies the peer device.
	id  string
	key string

	conn   *websocket.Conn
	engine *Engine

	// Event callbacks, installed through the WithOn* options.
	onPing         OnPingCallback
	onPong         OnPongCallback
	onMessage      OnMessageCallback
	onConnected    OnConnectedCallback
	onDisConnected OnDisConnectedCallback
	onError        OnErrorCallback
}

// WithContextValue derives the client's context with key/val attached, so the value
// can later be read back through Context().Value(key).
func WithContextValue(key, val any) ClientOption {
	return func(client *Client) {
		parent := client.ctx
		client.ctx = context.WithValue(parent, key, val)
	}
}

// WithOnPing installs cb as the client's ping callback.
func WithOnPing(cb OnPingCallback) ClientOption {
	setPing := func(client *Client) { client.onPing = cb }
	return setPing
}

// WithOnPong installs cb as the client's pong callback.
func WithOnPong(cb OnPongCallback) ClientOption {
	setPong := func(client *Client) { client.onPong = cb }
	return setPong
}

// WithOnMessage installs cb as the callback for received messages.
func WithOnMessage(cb OnMessageCallback) ClientOption {
	setMessage := func(client *Client) { client.onMessage = cb }
	return setMessage
}

// WithOnConnected installs cb as the client's "connected" callback.
func WithOnConnected(cb OnConnectedCallback) ClientOption {
	setConnected := func(client *Client) { client.onConnected = cb }
	return setConnected
}

// WithOnDisConnected installs cb as the client's "disconnected" callback.
func WithOnDisConnected(cb OnDisConnectedCallback) ClientOption {
	setDisConnected := func(client *Client) { client.onDisConnected = cb }
	return setDisConnected
}

// WithOnError installs cb as the client's error callback.
func WithOnError(cb OnErrorCallback) ClientOption {
	setError := func(client *Client) { client.onError = cb }
	return setError
}

// Context returns the client's context, including any values attached through
// WithContextValue. Close calls the client's cancel function; presumably that
// cancels this context (the pairing is set up outside this view — TODO confirm).
func (c *Client) Context() context.Context {
	ctx := c.ctx
	return ctx
}

// ID returns the logical unique identifier of this websocket client (for example,
// a user's employee number), as passed to init.
func (c *Client) ID() string {
	id := c.id
	return id
}

// ClientKey returns the unique identifier of the peer device behind this websocket.
// The field is not assigned in this file view; presumably the engine sets it.
func (c *Client) ClientKey() string {
	key := c.key
	return key
}

// init finishes setting up the client exactly once; later calls are no-ops.
// It does the following:
//   - stores the logical id and applies the options;
//   - configures the read limit and the initial read deadline;
//   - installs the ping/pong/close handlers;
//   - starts the read and write goroutines.
//
// Callbacks that were never configured (nil) are skipped rather than invoked.
func (c *Client) init(id string, options []ClientOption) {
	c.once.Do(func() {
		c.id = id
		for _, opt := range options {
			opt(c)
		}

		// Cap the size of one incoming message. If the peer exceeds it, the connection
		// sends a close frame to the peer and ReadMessage returns ErrReadLimit.
		c.conn.SetReadLimit(c.engine.maxMessageSize)

		// Arm the initial read deadline. The pong handler below only pushes the
		// deadline forward once a pong arrives. Without an initial value, a peer that
		// never answers our pings would leave ReadMessage with no deadline at all.
		if err := c.conn.SetReadDeadline(time.Now().Add(c.engine.pingPeriod + c.engine.deadLineWait)); err != nil {
			logx.Errorf("websocket client: [%s] set read deadline failed: %v", c.ID(), err)
		}

		// Ping: notify the callback, then answer with a pong carrying the same payload.
		c.conn.SetPingHandler(func(data string) error {
			logx.Debugf("websocket client: [%s] receive ping.", c.ID())

			if cb := c.onPing; cb != nil {
				c.engine.submitOnFunc(func() { cb(c, data) })
			}

			err := c.conn.WriteControl(int(message2.TypePong), []byte(data), time.Now().Add(c.engine.deadLineWait))
			if errors.Is(err, websocket.ErrCloseSent) {
				// A close frame was already sent, so failing to pong is expected.
				return nil
			}
			return err
		})

		// Pong: notify the callback, then extend the read deadline (pingPeriod + deadLineWait).
		c.conn.SetPongHandler(func(data string) error {
			logx.Debugf("websocket client: [%s] receive pong.", c.ID())

			if cb := c.onPong; cb != nil {
				c.engine.submitOnFunc(func() { cb(c, data) })
			}

			return c.conn.SetReadDeadline(time.Now().Add(c.engine.pingPeriod + c.engine.deadLineWait))
		})

		// Close: echo the peer's close code back (without a reason), then unregister
		// the client with the peer's code and reason.
		c.conn.SetCloseHandler(func(code int, text string) error {
			logx.Debugf("websocket client: [%s] receive close, code: %d, resaon: %s", c.ID(), code, text)
			msg := websocket.FormatCloseMessage(code, "")
			// Best effort: the peer may already be gone.
			_ = c.conn.WriteControl(int(message2.TypeClose), msg, time.Now().Add(c.engine.deadLineWait))
			c.engine.UnRegister(c, &message2.CloseMessage{Code: message2.CloseCode(code), Value: text})
			return nil
		})

		// Start full-duplex communication: one reader and one writer goroutine.
		threading.GoSafe(c.initRead)
		threading.GoSafe(c.initWrite)
	})
}

// initRead is the client's read loop. It hands every received message to the
// onMessage callback (when set) through the engine. It stops when the client's
// context is done or when a read fails, and unregisters the client on exit.
func (c *Client) initRead() {
	defer func() {
		c.engine.UnRegister(c)
	}()

	for {
		select {
		case <-c.ctx.Done():
			return
		default:
		}

		mt, mv, err := c.conn.ReadMessage()
		if err != nil {
			// Read errors are permanent in gorilla/websocket: every later read returns
			// the same error, and repeated reads on a failed connection eventually
			// panic. Leave the loop instead of spinning on the dead connection.
			c.onInternalError(err)
			return
		}

		v := &message2.Message{Type: message2.Type(mt), Value: byteutil.ToString(mv)}
		if cb := c.onMessage; cb != nil {
			c.engine.submitOnFunc(func() { cb(c, v) })
		}
	}
}

// initWrite is the client's keep-alive loop. Every pingPeriod it sends a ping and
// then extends the write deadline. It stops when the context is done or a write
// fails, and it stops the ticker and unregisters the client on exit.
func (c *Client) initWrite() {
	pingTicker := time.NewTicker(c.engine.pingPeriod)

	defer func() {
		pingTicker.Stop()
		c.engine.UnRegister(c)
	}()

	for {
		select {
		case <-pingTicker.C:
			// WriteControl may run concurrently with other Conn methods, so it needs no lock.
			if err := c.conn.WriteControl(int(message2.TypePing), nil, time.Now().Add(c.engine.deadLineWait)); err != nil {
				c.onInternalError(err)
				return
			}
			if err := c.extendWriteDeadline(); err != nil {
				c.onInternalError(err)
				return
			}
		case <-c.ctx.Done():
			return
		}
	}
}

// extendWriteDeadline moves the write deadline to now + pingPeriod + deadLineWait.
// gorilla/websocket counts SetWriteDeadline as a write method that must not run
// concurrently with WriteMessage, so it takes the same mutex that Send holds.
func (c *Client) extendWriteDeadline() error {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	return c.conn.SetWriteDeadline(time.Now().Add(c.engine.pingPeriod + c.engine.deadLineWait))
}

// Send writes msg to the peer as a single websocket message of type msg.Type.
//
// gorilla/websocket allows only one concurrent writer per connection, so writes
// are serialized with c.mutex to avoid "concurrent write to websocket connection".
// See https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency
func (c *Client) Send(msg *message2.Message) error {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	messageType := int(msg.Type)
	payload := strutil.ToBytes(msg.Value)
	return c.conn.WriteMessage(messageType, payload)
}

// Close closes the connection. It does three things:
//   - sends a close frame built from msg to the peer (best effort);
//   - cancels the client via c.cancel;
//   - closes the underlying connection.
//
// Only the first call does any work; later calls return nil. A nil msg is sent as
// a normal closure without a reason, rather than panicking inside closeOnce.
// Such a panic would mark the Once as done and leave the socket open for good.
func (c *Client) Close(msg *message2.CloseMessage) (err error) {
	c.closeOnce.Do(func() {
		payload := websocket.FormatCloseMessage(int(message2.CloseNormalClosure), "")
		if msg != nil {
			payload = msg.Format()
		}
		// Best effort: the peer may already be gone, and we close the socket regardless.
		_ = c.conn.WriteControl(int(message2.TypeClose), payload, time.Now().Add(c.engine.deadLineWait))
		c.cancel()
		err = c.conn.Close()
	})
	return err
}

// onInternalError hands err to the onError callback (via the engine), but only
// when err is a *websocket.CloseError whose code is not one of the expected ones:
// normal closure, going away, or no status received. Every other error, including
// errors that are not close errors at all, is dropped here. The callback is
// skipped if it was never configured.
func (c *Client) onInternalError(err error) {
	cb := c.onError
	if cb == nil {
		return
	}
	if !websocket.IsUnexpectedCloseError(err, int(message2.CloseNormalClosure), int(message2.CloseGoingAway), int(message2.CloseNoStatusReceived)) {
		return
	}
	c.engine.submitOnFunc(func() { cb(c, err) })
}
